[Operators] Add batch support for x86 CPU matrix multiplication + resolve rule #415
Conversation
deleted redundant file
use the original name
Thanks @BolinSNLHM !
Glad to see that you extended it to a batch version, but there are still a few things we can do to make it better. See the comments below.
super().__init__(
    name='matmul_f32_x86',
    inputs=[a, b],
    outputs=[c],
-   attributes={'m_size': a_shape[-2], 'n_size': b_shape[-1], 'k_size': a_shape[-1]},
+   attributes={'batch_size': batch_size, 'm_size': m_size, 'n_size': n_size, 'k_size': k_size},
)
We can use this function https://github.com/hidet-org/hidet/blob/main/python/hidet/ir/compute/cops/matmul.py#L30 to add the computation definition for matmul, like:
c = cops.matmul(a, b, allow_1d=True)
and (not is_constant(a.shape[0], b.shape[0]) or a.shape[0] == b.shape[0])
and (not is_constant(a.shape[2], b.shape[1]) or a.shape[2] == b.shape[1])
Do you require the shapes of a and b to be constant?
Is it possible to support dynamic shapes, as in https://github.com/hidet-org/hidet/blob/main/python/hidet/graph/ops/matmul/batch_matmul.py ?
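The pattern in the diffed condition is worth spelling out: a dimension check should only fail when both dimensions are known constants and disagree; if either side is symbolic, the check is deferred to runtime. A minimal standalone sketch of that pattern (the `is_constant` and `dims_match` helpers here are hypothetical stand-ins for hidet's own utilities):

```python
def is_constant(*dims):
    # Hypothetical stand-in: hidet's is_constant accepts several dims at once
    # and is true only when all of them are concrete integers.
    return all(isinstance(d, int) for d in dims)

def dims_match(x, y):
    # If either dim is dynamic (symbolic), defer the check to runtime;
    # if both are constant, they must be equal.
    return not is_constant(x, y) or x == y

print(dims_match(12, 12))   # True: equal constants
print(dims_match(12, 8))    # False: mismatched constants
print(dims_match("n", 12))  # True: symbolic dim, accepted at compile time
```

This is exactly the shape the diff's `not is_constant(a.shape[0], b.shape[0]) or a.shape[0] == b.shape[0]` clause takes, so relaxing the operator to dynamic shapes should not require changing the condition itself.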
# if not (len(a.shape) == len(b.shape) == 2 and a.shape[1] == b.shape[0]):
#     raise ValueError('Matrix multiplication: incompatible sizes: {} and {}'.format(a.shape, b.shape))
if not (
    len(a.shape) == len(b.shape) == 3
Can we use the same template to support matmul like:
- [12, 1024, 1024] @ [1024, 1024]
- [12, 1024, 1024] @ [1, 1024, 1024]
- [1024, 1024] @ [4, 5, 1024, 1024]
You can have a look at https://github.com/hidet-org/hidet/blob/main/python/hidet/graph/ops/matmul/batch_matmul.py as a reference.
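The three cases above follow standard batched-matmul broadcasting: a 2-D operand is broadcast over the other operand's batch dimensions, and batch dimensions of size 1 are expanded. A small NumPy sketch of the intended semantics (NumPy is used purely as an illustration, with sizes scaled down from 1024 to 8; hidet's kernel would implement the same rule natively):

```python
import numpy as np

a = np.ones((12, 8, 8))
b = np.ones((8, 8))
print((a @ b).shape)   # (12, 8, 8): 2-D rhs broadcasts over the batch dim

c = np.ones((1, 8, 8))
print((a @ c).shape)   # (12, 8, 8): batch dim of size 1 broadcasts to 12

d = np.ones((4, 5, 8, 8))
print((b @ d).shape)   # (4, 5, 8, 8): 2-D lhs broadcasts over both batch dims
```

Supporting these in one template usually amounts to normalizing both inputs to a common batch shape before dispatching to the 3-D kernel.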
Co-authored-by: Yaoyao Ding <[email protected]>
Preliminary performance testing results compared to PyTorch: